//-*-C++-*-
/***************************************************************************
 *
 *   Copyright (C) 2016 by Andrew Jameson and Willem van Straten
 *   Licensed under the Academic Free License version 2.1
 *
 ***************************************************************************/

#include "dsp/Transformation.h"
#include "dsp/WeightedTimeSeries.h"
#include "dsp/BitSeries.h"
#include "dsp/Memory.h"
#include "EventEmitter.h"

#ifndef __SpectralKurtosis_h
#define __SpectralKurtosis_h

// Integer codes naming zap categories. They are not referenced anywhere else
// in this header.
// NOTE(review): presumably SKFB / FSCR / TSCR correspond to the detect_skfb,
// detect_fscr and detect_tscr stages of dsp::SpectralKurtosis, and ZAP_ALL to
// their combination -- confirm against the code that uses these macros.
#define ZAP_ALL  0
#define ZAP_SKFB 1
#define ZAP_FSCR 2
#define ZAP_TSCR 3

namespace dsp {

  //! Forward declaration; this header refers to SKLimits only through pointers
  class SKLimits;

  //! Perform Spectral Kurtosis on Input Timeseries, creating output Time Series
  /*! Output will be in time, frequency, polarization order.

      SK may be evaluated at one or more resolutions (see the private
      Resolution class); the single-value setters and getters declared inline
      in this class act on the first resolution, resolution[0]. */

  class SpectralKurtosis: public Transformation<TimeSeries,TimeSeries> {

  public:

    //! Default constructor
    SpectralKurtosis ();

    //! Destructor
    ~SpectralKurtosis ();

    //! Query whether the specified TimeSeries::Order is supported
    bool get_order_supported (TimeSeries::Order order) const;

    //! Load configuration from YAML filename
    void load_configuration (const std::string& filename);

    //! Set M, the number of samples used in each SK estimate, of resolution[0]
    /*! NOTE(review): resolution[0] is accessed without a bounds check;
        presumably the constructor guarantees at least one resolution --
        confirm. */
    void set_M (unsigned _M) { resolution[0].set_M(_M); }

    //! Set M from a vector of values
    /*! NOTE(review): presumably one value per resolution -- confirm against
        the implementation. */
    void set_M (const std::vector<unsigned>&);

    //! Set the number of overlapping regions per time sample
    /*! Acts on resolution[0], which is accessed without a bounds check. */
    void set_noverlap (unsigned _nover) { resolution[0].set_noverlap(_nover); }

    //! Set the number of overlapping regions from a vector of values
    /*! NOTE(review): presumably one value per resolution -- confirm against
        the implementation. */
    void set_noverlap (const std::vector<unsigned>&);

    //! Set the RFI thresholds with the specified factor
    void set_thresholds (float _std_devs);

    //! Set the RFI thresholds from a vector of factors
    /*! NOTE(review): presumably one factor per resolution -- confirm against
        the implementation. */
    void set_thresholds (const std::vector<float>&);

    //! Evaluate SK for every time sample and frequency channel
    void set_detect_time_freq (bool flag=true) { detect_time_freq = flag; }

    //! Evaluate SK for every frequency channel after integrating over all time samples
    void set_detect_freq (bool flag=true) { detect_freq = flag; }

    //! Evaluate SK for every time sample after integrating over all frequency channels
    void set_detect_time (bool flag=true) { detect_time = flag; }

    //! On detecting an outlier, omit sample from any future computation of SK
    /*! This applies when computing SK for multiple resolutions and when integrating over time or frequency. */
    void set_omit_outliers (bool flag=true) { omit_outliers = flag; }

    //! Get the number of frequency channels for which SK is computed
    unsigned get_nchan () const { return nchan; }

    //! Get the number of polarizations for which SK is computed
    /*! Returns sums_npol, which is stored separately from npol. */
    unsigned get_npol () const { return sums_npol; }

    //! Set the channel range to conduct detection
    void set_channel_range (unsigned start, unsigned end);

    void reserve ();

    void prepare ();

    void prepare_output ();

    //! Get the time delay of this operation, if any, in seconds
    double get_delay_time () const override;

    //! The number of time samples used to calculate the SK statistic
    /*! Reports the value of resolution[0], which is accessed without a bounds check. */
    unsigned get_M () const
    { return resolution[0].get_M(); }

    //! The excision threshold of resolution[0], in number of standard deviations
    /*! The threshold is stored as a float (Resolution::std_devs) but is
        returned as unsigned, so any fractional part is discarded (for example
        3.5 is reported as 3).  The return type is kept for compatibility with
        existing callers; the conversion is spelled out so that the truncation
        is visible rather than implicit. */
    unsigned get_excision_threshold () const
    { return static_cast<unsigned>(resolution[0].get_std_devs()); }

    //! Total SK statistic for each poln/channel, post filtering
    /*! Copies the internal vector into sum. */
    void get_filtered_sum (std::vector<float>& sum) const
    { sum = filtered_sum; }

    //! Hits on filtered average for each channel
    /*! Copies the internal vector into hits. */
    void get_filtered_hits (std::vector<uint64_t>& hits) const
    { hits = filtered_hits; }

    //! Total SK statistic for each poln/channel, before filtering
    /*! Copies the internal vector into sum. */
    void get_unfiltered_sum (std::vector<float>& sum) const
    { sum = unfiltered_sum; }

    //! Hits on unfiltered SK statistic, same for each channel
    uint64_t get_unfiltered_hits () const { return unfiltered_hits; }

    //! The arrays will be reset when count_zapped is next called
    /*! This method itself only sets unfiltered_hits to zero. */
    void reset_count () { unfiltered_hits = 0; }

    //! Engine used to perform computations on device other than CPU
    class Engine;

    //! Set the engine used to perform computations
    void set_engine (Engine*);

    //! Base class of the callbacks registered with the reporter event emitters
    /*! The default operator() does nothing; derived classes override it.
        Judging by the comment on float_reporter, the arguments are the data
        array followed by its nchan, npol, ndat and ndim
        (NOTE(review): confirm the argument order at the call sites). */
    template<class T>
    class Reporter {
    public:
      //! Virtual destructor
      /*! Reporter is a polymorphic base class; without a virtual destructor,
          destroying a derived reporter through a Reporter pointer would be
          undefined behaviour. */
      virtual ~Reporter () = default;

      virtual void operator() (T*, unsigned, unsigned, unsigned, unsigned) {}
    };

    // An event emitter that takes a data array, and the nchan, npol, ndat and ndim
    // associated with the data array
    EventEmitter<Reporter<float> > float_reporter;

    // This is for reporting the state of the bit zapmask
    EventEmitter<Reporter<unsigned char> > char_reporter;

    //! Get the flag that enables reporting of intermediate data products
    bool get_report () const { return report; }

    //! Set the flag that enables reporting of intermediate data products
    void set_report (bool _report) { report = _report; }

    //! Return true if the zero_DM_input attribute has been set
    bool has_zero_DM_input () const;

    //! Set the zero_DM_input attribute
    virtual void set_zero_DM_input (TimeSeries* zero_DM_input);

    //! Get the zero_DM_input attribute (read-only access)
    virtual const TimeSeries* get_zero_DM_input() const;

    //! Get the zero_DM_input attribute
    virtual TimeSeries* get_zero_DM_input();

    // bool has_zero_DM_input_container () const;
    // virtual void set_zero_DM_input_container (const HasInput<TimeSeries> zero_DM_input_container&);
    // virtual const HasInput<TimeSeries>& get_zero_DM_input_container() const;
    // virtual HasInput<TimeSeries>& get_zero_DM_input_container();

    //! Set the buffering policy associated with the zero-DM input
    virtual void set_zero_DM_buffering_policy (BufferingPolicy* policy)
    { zero_DM_buffering_policy = policy; }

    //! Return the zero_DM_buffering_policy reference converted to bool
    bool has_zero_DM_buffering_policy() const
    { return zero_DM_buffering_policy; }

    //! Get the buffering policy associated with the zero-DM input
    BufferingPolicy* get_zero_DM_buffering_policy () const
    { return zero_DM_buffering_policy; }

    //! get the zero_DM flag
    bool get_zero_DM () const { return zero_DM; }

    //! set the zero_DM flag
    void set_zero_DM (bool _zero_DM) { zero_DM = _zero_DM; }

  protected:

    //! Perform the transformation on the input time series
    void transformation ();

    //! Interface to alternate processing engine (e.g. GPU)
    Reference::To<Engine> engine;

  private:

    // Processing stages; all are defined outside of this header.
    // NOTE(review): the ires argument presumably indexes the resolution
    // vector -- confirm against the implementation.

    void compute ();

    void detect ();
    void detect_tscr ();
    void detect_skfb (unsigned ires);
    void detect_fscr (unsigned ires);
    void count_zapped ();

    void mask ();
    void reset_mask ();

    void insertsk ();

    // NOTE(review): not referenced anywhere else in this header
    unsigned debugd = 1;

    //! Pair of counters: number of SK estimates tested and number flagged (zapped)
    class Count
    {
      public:
        uint64_t tested = 0;
        uint64_t zapped = 0;

        //! Fraction zapped/tested; 0.0 when nothing has been tested (avoids division by zero)
        double fraction_zapped () const { return (tested) ? double(zapped)/tested : 0.0; }

        //! Add the counters of that to the counters of this
        const Count& operator+= (const Count& that) { tested+=that.tested; zapped+=that.zapped; return *this;}
    };

    //! Parameters and counters of SK estimation at a single resolution
    class Resolution
    {
    private:

      //! frequency channels to be zapped
      /*! mutable, so that it can be modified by const methods */
      mutable std::vector<bool> channels;

      //! lower and upper thresholds of excision limits
      /*! mutable, so that it can be modified by const methods */
      mutable std::vector<float> thresholds;

      //! number of samples used in each SK estimate
      unsigned M = 128;

      //! number of instances integrated in each sample
      unsigned Nd = 1;

      //! Standard deviation used to compute thresholds
      float std_devs = 3.0;

      //! denominator of overlap factor
      unsigned noverlap = 1;

      //! sample offset to start of next overlapping M-sample block
      unsigned overlap_offset = 128;

      //! number of SK estimates produced
      uint64_t npart = 0;

      //! number of output time samples
      uint64_t output_ndat = 0;

      //! Count of time-frequency SK estimates tested and flagged as outliers
      Count SK_time_freq;

      //! Count of frequency-integrated SK estimates tested and flagged as outliers
      Count SK_time;

      //! compute the min and max SK thresholds
      void set_thresholds (bool verbose = false) const;

    public:

      //! Add a range of frequency channels to be zapped
      /*! from first to second inclusive; e.g. (0,1023) = 1024 channels */
      void add_include (const std::pair<unsigned, unsigned>&);

      //! Add a range of frequency channels not to be zapped
      /*! from first to second inclusive; e.g. (0,1023) = 1024 channels */
      void add_exclude (const std::pair<unsigned, unsigned>&);

      //! Get the channels to be zapped
      const std::vector<bool>& get_channels (unsigned nchan) const;

      //! number of samples used in each SK estimate
      unsigned get_M () const { return M; }
      void set_M (unsigned);

      //! number of instances integrated in each sample
      unsigned get_Nd () const { return Nd; }
      void set_Nd (unsigned);

      //! number of std devs used to calculate excision limits
      float get_std_devs () const { return std_devs; }
      void set_std_devs (float);

      //! denominator of overlap factor
      /* blocks used to estimate SK overlap by (noverlap-1)/noverlap */
      unsigned get_noverlap () const { return noverlap; }
      void set_noverlap (unsigned);

      //! ensure that noverlap divides M and compute overlap_offset, npart, and output_ndat
      void prepare (uint64_t ndat = 0);

      //! number of time samples offset between consecutive (possibly overlapping) SK blocks
      unsigned get_overlap_offset () const { return overlap_offset; }

      //! number of SK estimates produced
      uint64_t get_npart () const { return npart; }

      //! number of output time samples flagged
      uint64_t get_output_ndat () const { return output_ndat; }

      //! ensure that this shares boundaries with that
      void compatible (Resolution& that);

      //! lower and upper thresholds of excision limits
      const std::vector<float>& get_thresholds () const;

      //! ranges of frequency channels to be zapped
      std::vector< std::pair<unsigned,unsigned> > include;

      //! ranges of frequency channels not to be zapped
      std::vector< std::pair<unsigned,unsigned> > exclude;

      //! Increment counts of time-frequency SK estimates tested and flagged as outliers
      void increment_time_freq (const Count&);
      const Count& get_count_time_freq () const;

      //! Increment counts of frequency-integrated SK estimates tested and flagged as outliers
      void increment_time (const Count&);
      const Count& get_count_time () const;
    };

    //! The resolutions at which SK is evaluated
    /*! The inline single-value accessors of this class use element 0. */
    std::vector<Resolution> resolution;

    void resize_resolution (unsigned);

    //! integrate the S1 and S2 sum to new M and noverlap
    void tscrunch_sums (Resolution& from, Resolution& to);

    // for sorting by M
    static bool by_M (const Resolution& A, const Resolution& B);

    unsigned nchan = 0;
    unsigned npol = 0;
    unsigned ndim = 0;
    unsigned integrated_Nd = 1;

    //! Number of polarizations reported by get_npol
    unsigned sums_npol = 0;

    //! S1 and S2 sums
    Reference::To<WeightedTimeSeries> sums;

    //! Zap mask
    Reference::To<BitSeries> zapmask;

    //! Upper and lower bounds on SK for each Mprime count encountered
    std::vector<SKLimits*> dynamic_limits;
    const SKLimits* get_limits (unsigned count, float std_devs);

    //! Total SK statistic for each poln/channel, post filtering
    std::vector<float> filtered_sum;

    //! Hits on filtered average for each channel
    std::vector<uint64_t> filtered_hits;

    //! Total SK statistic for each poln/channel, before filtering
    std::vector<float> unfiltered_sum;

    //! Hits on unfiltered SK statistic, same for each channel
    uint64_t unfiltered_hits = 0;

    //! Count of time-integrated SK estimates tested and flagged as outliers
    Count SK_freq;

    //! flags for detection types
    bool detect_time_freq = true;
    bool detect_freq = true; // after tscrunch
    bool detect_time = true; // after fscrunch

    //! On detecting an outlier, omit sample (S1, S2, and count) from any future computation of SK.
    /*! This applies only when computing SK on multiple timescales and/or when detect_freq is true. */
    bool omit_outliers = true;

    bool prepared = false;

    //! flag that indicates whether or not to report intermediate data products
    //! via the *_report EventEmitter objects.
    bool report = false;

    // //! Input TimeSeries that has not been dedispersed in some previous operation.
    // Reference::To<dsp::TimeSeries> zero_DM_input;

    //! HasInput container for zero_DM_input TimeSeries
    HasInput<TimeSeries> zero_DM_input_container;

    Reference::To<BufferingPolicy> zero_DM_buffering_policy;

    bool zero_DM = false;
    double delay_time = 0.0;
  };

  //! Abstract interface to an alternate processing engine for SpectralKurtosis
  /*! Every method is pure virtual, so a concrete engine must implement all
      of them.  NOTE(review): the per-method remarks below are inferred from
      the method and parameter names and from the similarly named private
      methods of SpectralKurtosis -- confirm against a concrete engine. */
  class SpectralKurtosis::Engine : public Reference::Able
  {
  public:

    //! Prepare the engine for use
    virtual void setup () = 0;

    //! Compute the SK inputs from input into output and output_tscr
    virtual void compute (const TimeSeries* input,
                          TimeSeries* output,
                          TimeSeries* output_tscr,
                          unsigned tscrunch) = 0;

    //! Reset the zap mask
    virtual void reset_mask (BitSeries* output) = 0;

    //! Detect outliers using the given upper and lower thresholds
    virtual void detect_ft (const TimeSeries* input,
                            BitSeries* output,
                            float upper_thresh,
                            float lower_thresh) = 0;

    //! Detect outliers after scrunching in frequency
    /*! schan and echan presumably delimit the channel range -- confirm. */
    virtual void detect_fscr (const TimeSeries* input,
                              BitSeries* output,
                              float mu2,
                              float std_devs,
                              unsigned schan,
                              unsigned echan) = 0;

    //! Detect outliers after scrunching in time
    virtual void detect_tscr (const TimeSeries* input,
                              const TimeSeries* input_tscr,
                              BitSeries* output,
                              float upper,
                              float lower) = 0;

    //! Count entries of the zap mask
    virtual int count_mask (const BitSeries* output) = 0;

    //! Get a pointer to the SK estimates held in input
    virtual float* get_estimates (const TimeSeries* input) = 0;

    //! Get a pointer to the zap mask data held in input
    virtual unsigned char* get_zapmask (const BitSeries* input) = 0;

    //! Apply the zap mask to in, writing the result to out
    virtual void mask (BitSeries* mask,
                       const TimeSeries* in,
                       TimeSeries* out,
                       unsigned M) = 0;

    //! Insert SK data from input into out
    virtual void insertsk (const TimeSeries* input,
                           TimeSeries* out,
                           unsigned M) = 0;

  };
}

#endif
